{-# LANGUAGE DeriveDataTypeable #-}

-- | A type-safe list that has at least one element.
module Data.List.NonEmpty
  (
    NonEmpty
    -- * Accessors
  , neHead
  , neTail
    -- * Constructors
  , nonEmpty
  , (|:)
  , toNonEmpty
  , toNonEmpty'
  , unsafeToNonEmpty
  , (.:)
    -- * List functions
  , reverse
  , scanl
  , scanl1
  , scanr
  , scanr1
  , iterate
  , cycle
  , inits
  , tails
  , sort
  , insert
  , unzip
  ) where

import Control.Applicative
import Control.Monad
import Control.Comonad
import Control.Functor.Zip
import Control.Arrow
import Data.Foldable
import Data.Maybe
import Data.Traversable
import Data.Typeable (Typeable)
import Data.Data (Data)
import Data.Semigroup
import qualified Data.List as L
import Prelude hiding (foldr, reverse, scanl, scanl1, scanr, scanr1, iterate, repeat, cycle, unzip)
import Test.QuickCheck hiding (NonEmpty)

-- | A list that is guaranteed by construction to hold at least one element:
-- a mandatory head followed by an ordinary (possibly empty) list tail.
--
-- The derived 'Ord' instance compares the heads first and then the tails,
-- which agrees with comparing the elements as an ordinary list.
data NonEmpty a = NonEmpty
  { neHead :: a   -- ^ The first element; always present.
  , neTail :: [a] -- ^ The remaining elements, possibly none.
  }
  deriving (Eq, Ord, Typeable, Data)

instance Functor NonEmpty where
  -- Apply the function to the head and to every element of the tail.
  fmap g (NonEmpty x xs) = NonEmpty (g x) (map g xs)

instance Pointed NonEmpty where
  -- The singleton list: the value as the head, with an empty tail.
  point x = NonEmpty x []

instance Copointed NonEmpty where
  -- The focused element of a non-empty list is its head.
  extract (NonEmpty x _) = x

instance Applicative NonEmpty where
  -- The singleton list.
  pure x = NonEmpty x []
  -- Derived from the Monad instance in the same way as 'ap': every function
  -- is applied to every argument, functions varying slowest.
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)

instance Monad NonEmpty where
  -- The singleton list.
  return x = NonEmpty x []

  -- Concatenate the results of applying @k@ to every element, in order. The
  -- result for the head supplies the new head; its tail is followed by the
  -- flattened results for the remaining elements. The constructor pattern on
  -- @k x@ is a @where@ binding and therefore lazy, so the outer 'NonEmpty' is
  -- produced without evaluating @k x@.
  NonEmpty x xs >>= k = NonEmpty y (ys ++ L.concatMap (toList . k) xs)
    where
      NonEmpty y ys = k x

instance Comonad NonEmpty where
  -- Replace each position with the suffix starting there: the whole list
  -- first, then the suffix after the head, and so on down to the singleton
  -- last element (a non-empty analogue of 'tails' without the empty suffix).
  duplicate whole@(NonEmpty _ rest) =
    NonEmpty whole (maybe [] (toList . duplicate) (toNonEmpty rest))

instance Foldable NonEmpty where
  -- Right fold: the head becomes the outermost application of @f@.
  foldr f z (NonEmpty h t) = f h (foldr f z t)
  -- Left fold over the head followed by the tail. This is the ordinary lazy
  -- left fold, so it agrees with folding 'toList' of the same list; callers
  -- that want a strict accumulator should use 'foldl'' explicitly.
  foldl f z (NonEmpty h t) = L.foldl f z (h : t)

instance Traversable NonEmpty where
  -- Run the action on the head, then on each tail element in order, and
  -- rebuild the non-empty list. Matching on the constructor keeps this total
  -- (no round-trip through a plain list and the partial 'head'/'tail').
  traverse f (NonEmpty h t) = NonEmpty <$> f h <*> traverse f t

instance (Show a) => Show (NonEmpty a) where
  -- Rendered like the equivalent list, wrapped in bars, e.g. @|[1,2,3]|@.
  show (NonEmpty x xs) = "|" ++ show (x : xs) ++ "|"

instance Semigroup (NonEmpty a) where
  -- Append: keep the left head, then the left tail, then every element of
  -- the right list (its head followed by its tail).
  NonEmpty x xs .++. NonEmpty y ys = NonEmpty x (xs ++ (y : ys))

instance Zip NonEmpty where
  -- Pair the heads, then zip the tails. The result is as long as the shorter
  -- input and can never be empty, because both heads are always present.
  fzip (NonEmpty a as') (NonEmpty b bs) = NonEmpty (a, b) (zip as' bs)

instance (Arbitrary a) => Arbitrary (NonEmpty a) where
  -- A random head paired with a random (possibly empty) tail.
  arbitrary = nonEmpty <$> arbitrary <*> arbitrary
  -- Shrink as an ordinary list. QuickCheck's list shrinker can propose the
  -- empty list, which has no 'NonEmpty' counterpart, so such candidates are
  -- dropped here rather than converted with the partial 'unsafeToNonEmpty'
  -- (which would crash in the middle of shrinking).
  shrink = mapMaybe toNonEmpty . shrink . toList

-- | Builds a non-empty list from a head element and a (possibly empty) tail.
nonEmpty :: a   -- ^ The head.
         -> [a] -- ^ The tail.
         -> NonEmpty a
nonEmpty h t = NonEmpty h t

-- | Operator form of 'nonEmpty': @x |: xs@ is the non-empty list whose head
-- is @x@ and whose tail is @xs@.
--
-- No fixity declaration is given, so the operator has the default
-- @infixl 9@.
(|:) :: a   -- ^ The head.
     -> [a] -- ^ The tail.
     -> NonEmpty a
x |: xs = NonEmpty x xs

-- | Safely converts a list: @Nothing@ for the empty list, otherwise @Just@
-- the 'NonEmpty' with the same elements.
toNonEmpty :: [a] -- ^ The list to convert.
           -> Maybe (NonEmpty a)
toNonEmpty xs = case xs of
  []      -> Nothing
  (h : t) -> Just (NonEmpty h t)

-- | Converts a list to a 'NonEmpty', falling back to the supplied default
-- when the list is empty. The default is not evaluated for non-empty input.
toNonEmpty' :: NonEmpty a -- ^ The result to use if the given list is empty.
            -> [a]        -- ^ The list to convert.
            -> NonEmpty a
toNonEmpty' fallback []      = fallback
toNonEmpty' _        (h : t) = NonEmpty h t

-- | Converts a list to a 'NonEmpty'.
--
-- /WARNING: partial./ For an empty input, demanding the result calls 'error'
-- with the message @unsafeToNonEmpty on empty list@. Prefer 'toNonEmpty' or
-- @toNonEmpty'@ unless the list is known to be non-empty.
unsafeToNonEmpty :: [a] -- ^ The list to convert (must not be empty).
                 -> NonEmpty a
unsafeToNonEmpty (h : t) = NonEmpty h t
unsafeToNonEmpty []      = error "unsafeToNonEmpty on empty list"

-- | Conses a value onto the front of a non-empty list; the previous head
-- moves to the front of the tail.
(.:) :: a          -- ^ The value to prepend.
     -> NonEmpty a -- ^ The non-empty list to prepend to.
     -> NonEmpty a
(.:) x (NonEmpty y ys) = NonEmpty x (y : ys)

-- | Reverses the elements of a finite non-empty list. It diverges on an
-- infinite list, because the last element must be found before anything can
-- be produced.
reverse :: NonEmpty a -- ^ A finite non-empty list.
        -> NonEmpty a
reverse (NonEmpty h t) = go (NonEmpty h []) t
  where
    -- Move elements one at a time from the remaining input onto the front of
    -- the accumulator, which is non-empty from the start, so no partial
    -- list-to-NonEmpty conversion is needed.
    go acc [] = acc
    go (NonEmpty y ys) (x : xs) = go (NonEmpty x (y : ys)) xs

-- | Left scan over a non-empty list. The result begins with the seed, so it
-- is never empty.
scanl :: (b -> a -> b) -- ^ Step function.
      -> b             -- ^ Seed (first element of the result).
      -> NonEmpty a
      -> NonEmpty b
scanl f z = unsafeToNonEmpty . L.scanl f z . toList

-- | Left scan without a seed: the head of the input is the starting
-- accumulator. The result has the same length as the input.
scanl1 :: (a -> a -> a) -- ^ Step function.
       -> NonEmpty a
       -> NonEmpty a
scanl1 f = unsafeToNonEmpty . L.scanl1 f . toList

-- | Right scan over a non-empty list. The result ends with the seed, and its
-- head is the corresponding right fold, so it is never empty.
scanr :: (a -> b -> b) -- ^ Step function.
      -> b             -- ^ Seed (last element of the result).
      -> NonEmpty a
      -> NonEmpty b
scanr f z = unsafeToNonEmpty . L.scanr f z . toList

-- | Right scan without a seed: the last element of the input is the starting
-- accumulator. The result has the same length as the input.
scanr1 :: (a -> a -> a) -- ^ Step function.
       -> NonEmpty a
       -> NonEmpty a
scanr1 f = unsafeToNonEmpty . L.scanr1 f . toList

-- | The infinite non-empty list of repeated applications:
-- @x@, @f x@, @f (f x)@, and so on.
iterate :: (a -> a) -- ^ Function to apply repeatedly.
        -> a        -- ^ Starting value (the head of the result).
        -> NonEmpty a
iterate f x = NonEmpty x (L.iterate f (f x))

-- | Repeats the elements of a structure forever, as a 'NonEmpty'.
--
-- /WARNING: partial./ If the given structure has no elements, demanding the
-- result fails (via @Data.List.cycle@); the type does not rule that out.
cycle :: (Foldable f)
      => f a -- ^ The elements to repeat (must not be empty).
      -> NonEmpty a
cycle xs = unsafeToNonEmpty (L.cycle (toList xs))

-- | All prefixes of a list, shortest first. The result always contains at
-- least the empty prefix, so it is a 'NonEmpty'.
inits :: [a] -- ^ The list whose prefixes are taken.
      -> NonEmpty [a]
inits xs = unsafeToNonEmpty (L.inits xs)

-- | All suffixes of a list, longest first: the list itself, then each shorter
-- suffix, ending with the empty list.
tails :: [a] -- ^ The list whose suffixes are taken.
      -> NonEmpty [a]
tails xs = NonEmpty xs rest
  where
    -- Suffixes after the first one; none at all for the empty list.
    rest = case xs of
      []       -> []
      (_ : ys) -> L.tails ys

-- | Sorts the elements in ascending order using @Data.List.sort@ (a stable
-- sort).
sort :: (Ord a)
     => NonEmpty a
     -> NonEmpty a
sort xs = unsafeToNonEmpty (L.sort (toList xs))

-- | Inserts an element before the first element that it is not greater than
-- (the same placement rule as @Data.List.insert@). If the input is sorted,
-- the result is sorted.
insert :: (Ord a)
       => a          -- ^ The element to insert.
       -> NonEmpty a -- ^ The list to insert into.
       -> NonEmpty a
insert x (NonEmpty h t) =
  case compare x h of
    GT -> NonEmpty h (L.insert x t)
    _  -> NonEmpty x (h : t)

-- | Splits a non-empty list of pairs into the non-empty list of first
-- components and the non-empty list of second components.
unzip :: NonEmpty (a, b)
      -> (NonEmpty a, NonEmpty b)
unzip ps = (unsafeToNonEmpty firsts, unsafeToNonEmpty seconds)
  where
    -- A lazy pattern binding, so neither half is computed until demanded.
    (firsts, seconds) = L.unzip (toList ps)

------------------
-- Not exported --
------------------

-- | Lifts a list transformation to a function from any 'Foldable' to a
-- 'NonEmpty'. Only safe when the transformation returns a non-empty list for
-- the inputs it is given; otherwise the result fails (via 'unsafeToNonEmpty')
-- when demanded.
list :: Foldable f
     => ([a] -> [b]) -- ^ List transformation (must produce a non-empty list).
     -> f a
     -> NonEmpty b
list f = unsafeToNonEmpty . f . toList

-- | Two-argument version of 'list': flattens both structures to lists,
-- combines them, and converts the result, which must be non-empty.
list2 :: Foldable f
      => ([a] -> [b] -> [c]) -- ^ Combining function (must produce a non-empty list).
      -> f a
      -> f b
      -> NonEmpty c
list2 f xs ys = list (f (toList xs)) ys
